# coding=utf-8
import os
import numpy as np
import tensorflow as tf
from data import Data
from model import alexnet

import npu_bridge
from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig

# Session configuration for running on an Ascend NPU via npu_bridge.
config = tf.ConfigProto()
# Register the NPU graph optimizer as a custom Grappler pass.
custom_op = config.graph_options.rewrite_options.custom_optimizers.add()
custom_op.name = "NpuOptimizer"
# Allow graphs with dynamic input shapes.
custom_op.parameter_map["dynamic_input"].b = True
# "lazy_recompile": recompile the graph only when an unseen shape arrives.
custom_op.parameter_map["dynamic_graph_execute_mode"].s = tf.compat.as_bytes("lazy_recompile")
# Execute ops on the NPU device instead of falling back to the host CPU.
custom_op.parameter_map["use_off_line"].b = True
# Disable TF's remapping pass; it conflicts with the NPU optimizer.
config.graph_options.rewrite_options.remapping = RewriterConfig.OFF


flags = tf.flags
FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
    "train_url", "../output",
    "The output directory where the model checkpoints will be written.")

flags.DEFINE_string("data_url", "../dataset",
                    "dataset path")

## Other parameters
flags.DEFINE_integer(
    "num_classes", 5,
    """number of classes for datasets """)

flags.DEFINE_float(
    "learning_rate", 1e-3,
    "The initial learning rate for Adam.")

flags.DEFINE_integer(
    "batch_size", 32,
    "batch size for one NPU")

flags.DEFINE_integer(
    "train_step", 150,
    "total number of training epochs (each step is one full pass over the data)")

flags.DEFINE_integer(
    "decay_step", 500,
    "decay the learning rate every decay_step global steps")

flags.DEFINE_float(
    "decay_rate", 0.9,
    # Fixed help text: this is the multiplicative decay factor passed to
    # tf.train.exponential_decay, not an optimizer momentum (Adam has none here).
    "decay factor for the exponentially decayed learning rate")

def main(_):
    """Train AlexNet on the dataset under FLAGS.data_url and checkpoint the best model.

    Builds the graph (placeholders, AlexNet, softmax cross-entropy loss,
    Adam with an exponentially decayed learning rate), then trains for
    FLAGS.train_step epochs. Every 10 epochs the model is evaluated on the
    validation split, and the checkpoint with the best validation accuracy
    so far is written to FLAGS.train_url.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info("**********")
    data = Data(batch_size=FLAGS.batch_size, num_classes=FLAGS.num_classes,
                data_path=os.path.join(FLAGS.data_url, "train"),
                val_data=os.path.join(FLAGS.data_url, "val"))
    tf.logging.info("Label dict = %s", data.labels_dict)

    # Inputs: 224x224 RGB images and one-hot labels; keep_prob drives dropout.
    x = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3])
    y = tf.placeholder(dtype=tf.int32, shape=[None, FLAGS.num_classes])
    keep_prob = tf.placeholder(tf.float32)

    # construction model
    pred = alexnet(x, class_num=FLAGS.num_classes, keep_prob=keep_prob)

    # define loss function and optimizer; learning rate decays in discrete
    # steps (staircase) every FLAGS.decay_step global steps.
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step,
                                               FLAGS.decay_step, FLAGS.decay_rate,
                                               staircase=True)
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost, global_step=global_step)

    # definition accuracy: fraction of argmax predictions matching labels
    prediction_correction = tf.equal(tf.cast(tf.argmax(pred, 1), dtype=tf.int32),
                                     tf.cast(tf.argmax(y, 1), dtype=tf.int32), name='prediction')
    accuracy = tf.reduce_mean(tf.cast(prediction_correction, dtype=tf.float32), name='accuracy')

    tf.summary.scalar('loss', cost)
    tf.summary.scalar('accuracy', accuracy)
    summary_op = tf.summary.merge_all()

    # start training
    with tf.Session(config=config) as sess:
        init_op = tf.group(tf.local_variables_initializer(), tf.global_variables_initializer())
        sess.run(init_op)

        train_writer = tf.summary.FileWriter(logdir=os.path.join(FLAGS.train_url, "train"), graph=sess.graph)
        test_writer = tf.summary.FileWriter(logdir=os.path.join(FLAGS.train_url, "test"), graph=sess.graph)

        # saver is used to persist the best-so-far model
        saver = tf.train.Saver()
        max_acc = 0
        for step in range(FLAGS.train_step):
            tf.logging.info(" step = %d", step)
            batch_num = data.get_batch_num()
            # BUG FIX: accumulate metrics across ALL batches of the epoch.
            # Previously these lists were re-created inside the batch loop,
            # so the logged "epoch" mean was really just the last batch.
            train_accuracy_list = []
            train_loss_list = []
            summary = None
            for batch_count in range(batch_num):
                train_images, train_labels = data.get_batch(batch_count)
                # One optimizer step with dropout on, then re-evaluate the
                # metrics with dropout off (keep_prob=1) for clean logging.
                sess.run(optimizer, feed_dict={x: train_images, y: train_labels, keep_prob: 0.5})
                train_loss, train_acc, summary = sess.run(
                    [cost, accuracy, summary_op],
                    feed_dict={x: train_images, y: train_labels, keep_prob: 1.})
                train_loss_list.append(train_loss)
                train_accuracy_list.append(train_acc)
            # BUG FIX: guard against an empty epoch — `summary` was referenced
            # unconditionally before, raising NameError when batch_num == 0.
            if summary is not None:
                train_writer.add_summary(summary, step)
            tf.logging.info("train_acc = %s", np.mean(train_accuracy_list))
            tf.logging.info("train_loss = %s", np.mean(train_loss_list))
            # Evaluate on the validation split every 10 epochs.
            if (step + 1) % 10 == 0:
                val_images, val_labels = data.get_val_data()
                val_feed = {x: val_images, y: val_labels, keep_prob: 1.}
                test_loss, test_acc, summary = sess.run([cost, accuracy, summary_op], feed_dict=val_feed)
                test_writer.add_summary(summary, step)
                tf.logging.info("test_acc = %s", test_acc)
                tf.logging.info("test_loss = %s", test_loss)

                # save model whenever validation accuracy improves
                if test_acc > max_acc:
                    saver.save(sess=sess, save_path=os.path.join(FLAGS.train_url, "model.ckpt"))
                    max_acc = test_acc
        train_writer.close()
        test_writer.close()

if __name__ == "__main__":
    flags.mark_flag_as_required("data_url")
    flags.mark_flag_as_required("train_url")
    tf.app.run()